{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference layer: 2 input channels, 1 output channel, 2x2 kernel, stride 1, no padding.\n",
    "m = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=(2, 2), stride=1, padding=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin the convolution parameters to known values so the result can be checked by hand.\n",
    "k1 = torch.tensor([[1, -2], [2, 0]], dtype=torch.float32)\n",
    "k2 = torch.tensor([[-3, 4], [5, 3]], dtype=torch.float32)\n",
    "# Stack the two per-channel kernels, then add a leading axis of size 1:\n",
    "# the resulting shape is (1, 2, 2, 2).\n",
    "kernel = torch.stack([k1, k2])[None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the layer's weight with the fixed `kernel` tensor, wrapped as a Parameter.\n",
    "m.weight = nn.Parameter(kernel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero bias, so the layer output is the pure weighted sum.\n",
    "m.bias = nn.Parameter(torch.zeros(1, dtype=torch.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a two-channel 3x3 input to run through the convolution.\n",
    "i1 = torch.tensor([[1, 2, 3], [0, 2, -1], [2, 3, 1]], dtype=torch.float32)\n",
    "i2 = torch.tensor([[1, 3, -2], [2, 4, 0], [3, 3, 1]], dtype=torch.float32)\n",
    "# Shape (2, 3, 3): channel axis first, no batch axis.\n",
    "image = torch.stack([i1, i2], dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 1.,  2.,  3.],\n",
       "          [ 0.,  2., -1.],\n",
       "          [ 2.,  3.,  1.]],\n",
       " \n",
       "         [[ 1.,  3., -2.],\n",
       "          [ 2.,  4.,  0.],\n",
       "          [ 3.,  3.,  1.]]]),\n",
       " tensor([[[28.,  3.],\n",
       "          [34., 16.]]], grad_fn=<SqueezeBackward1>))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the input next to the output of the nn.Conv2d layer `m`.\n",
    "image, m(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv2dSimple:\n",
    "    '''\n",
    "    Minimal 2-D convolution written with explicit loops over output positions.\n",
    "\n",
    "    The weight has shape (out_channel, in_channel, kernel_h, kernel_w) and the\n",
    "    bias has shape (out_channel,). Both start as plain random tensors created\n",
    "    with torch.randn; assign an nn.Parameter to pin or train them.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0):\n",
    "        '''\n",
    "        Create the layer and randomly initialise its weight and bias.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        in_channel : int, number of input channels\n",
    "        out_channel : int, number of output channels\n",
    "        kernel_size : int or tuple(int, int), size of the local receptive field;\n",
    "            a single int k is treated as (k, k)\n",
    "        stride : int, step between neighbouring windows\n",
    "        padding : int, number of zeros added on every side of the input\n",
    "        '''\n",
    "        # Accept a bare int so that kernel_size=3 behaves like kernel_size=(3, 3).\n",
    "        if isinstance(kernel_size, int):\n",
    "            kernel_size = (kernel_size, kernel_size)\n",
    "        self.stride = stride\n",
    "        self.padding = padding\n",
    "        self.out_channel = out_channel\n",
    "        self.kernel_size = tuple(kernel_size)\n",
    "        self.weight = torch.randn((out_channel, in_channel) + self.kernel_size)\n",
    "        self.bias = torch.randn(out_channel)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        '''\n",
    "        Forward pass.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        x : torch.Tensor of shape (I, H, W) or (B, I, H, W)\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        torch.Tensor of shape (out_channel, H_out, W_out) for a 3-D input, or\n",
    "        (B, out_channel, H_out, W_out) for a 4-D input. The result is also\n",
    "        stored in self.out.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError : x is neither 3-D nor 4-D\n",
    "        RuntimeError : the channel count of x does not match the weight\n",
    "        '''\n",
    "        if x.dim() not in (3, 4):\n",
    "            raise ValueError(f'expected a 3-D or 4-D input, got {x.dim()}-D')\n",
    "        unbatched = x.dim() == 3\n",
    "        data = x.unsqueeze(0) if unbatched else x\n",
    "        # B: batch size, I: number of input channels, H: height, W: width\n",
    "        B, I, H, W = data.shape\n",
    "        # Without this check a size-1 channel axis would silently broadcast\n",
    "        # against the weight and produce a wrong result.\n",
    "        if I != self.weight.shape[1]:\n",
    "            raise RuntimeError(\n",
    "                f'expected {self.weight.shape[1]} input channels, got {I}'\n",
    "            )\n",
    "        kernel_h, kernel_w = self.kernel_size\n",
    "        h_step = (H + 2 * self.padding - kernel_h) // self.stride + 1\n",
    "        w_step = (W + 2 * self.padding - kernel_w) // self.stride + 1\n",
    "        # Allocate the output on the input's device with the dtype the arithmetic\n",
    "        # below produces, so results are not silently downcast or moved to CPU.\n",
    "        out_dtype = torch.promote_types(data.dtype, self.weight.dtype)\n",
    "        out_dtype = torch.promote_types(out_dtype, self.bias.dtype)\n",
    "        output = torch.zeros(\n",
    "            B, self.out_channel, h_step, w_step, dtype=out_dtype, device=data.device\n",
    "        )\n",
    "        # Zero-pad the borders: (left, right, top, bottom) of the last two axes.\n",
    "        data = F.pad(data, (self.padding,) * 4)\n",
    "        for i in range(h_step):\n",
    "            h_begin = i * self.stride\n",
    "            h_end = h_begin + kernel_h\n",
    "            for j in range(w_step):\n",
    "                w_begin = j * self.stride\n",
    "                w_end = w_begin + kernel_w\n",
    "                # inputs:      (B, 1,           I, kernel_h, kernel_w)\n",
    "                # self.weight: (   out_channel, I, kernel_h, kernel_w)\n",
    "                # linear_out:  (B, out_channel)\n",
    "                # self.bias:   (   out_channel)\n",
    "                inputs = data[:, :, h_begin:h_end, w_begin:w_end].unsqueeze(1)\n",
    "                linear_out = (inputs * self.weight).sum((-3, -2, -1))\n",
    "                output[:, :, i, j] = linear_out + self.bias\n",
    "        # Drop the batch axis again when the caller passed an unbatched input.\n",
    "        self.out = output.squeeze(0) if unbatched else output\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        '''Return the tensors of the layer as the list [weight, bias].'''\n",
    "        return [self.weight, self.bias]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 验证实现是否正确\n",
    "t = Conv2dSimple(2, 1, (2, 2), stride=1, padding=0)\n",
    "t.weight = nn.Parameter(kernel)\n",
    "t.bias = nn.Parameter(torch.tensor([0]).float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 1.,  2.,  3.],\n",
       "          [ 0.,  2., -1.],\n",
       "          [ 2.,  3.,  1.]],\n",
       " \n",
       "         [[ 1.,  3., -2.],\n",
       "          [ 2.,  4.,  0.],\n",
       "          [ 3.,  3.,  1.]]]),\n",
       " tensor([[[28.,  3.],\n",
       "          [34., 16.]]], grad_fn=<SqueezeBackward1>))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the hand-written layer against the nn.Conv2d layer `m` explicitly,\n",
    "# instead of relying on comparing the two printed outputs by eye.\n",
    "expected = m(image)\n",
    "actual = t(image)\n",
    "assert torch.allclose(actual, expected), 'Conv2dSimple disagrees with nn.Conv2d'\n",
    "image, actual"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
